import numpy as np
import bchlib
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants
import torch
import cv2
import os
class issbaEncoder(object):
    """Embed a fixed secret string into images with a saved StegaStamp model.

    The secret is converted once, at construction time, into a list of bits
    (payload bytes followed by their BCH error-correction bytes and four zero
    bits). Every call then feeds that bit list together with one image to the
    loaded model and returns the model's ``stegastamp`` output.
    """

    # Payload size, in bytes, that the secret is padded to before BCH encoding.
    _MAX_SECRET_BYTES = 7

    def __init__(self, model_path, secret, size, session):
        """Load the saved model into ``session`` and pre-compute the secret bits.

        Args:
            model_path: Directory of the SavedModel to load.
            secret: Text to embed; at most 7 bytes once UTF-8 encoded.
            size: Stored unchanged on ``self.size``; not used by this class.
            session: TensorFlow (v1-style) session the model is loaded into
                and that is used for inference.

        Raises:
            ValueError: If ``secret`` does not fit in 7 UTF-8 bytes.
        """
        BCH_POLYNOMIAL = 137
        BCH_BITS = 5

        # Validate first so that a bad secret fails fast, before the model is
        # loaded. The limit applies to the encoded bytes, not the character
        # count: a non-ASCII character takes more than one byte.
        data = bytearray(secret, 'utf-8')
        if len(data) > self._MAX_SECRET_BYTES:
            raise ValueError('Error: Can only encode 56bits (7 characters) with ECC')
        # Right-pad with spaces to the fixed payload size.
        data.extend(b' ' * (self._MAX_SECRET_BYTES - len(data)))

        self.size = size
        self.sess = session

        model = tf.compat.v1.saved_model.loader.load(self.sess, [tag_constants.SERVING], model_path)
        signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

        # Resolve tensors on the session's own graph (where the model was just
        # loaded) rather than on whatever graph happens to be the default.
        graph = self.sess.graph
        self.input_secret = graph.get_tensor_by_name(signature.inputs['secret'].name)
        self.input_image = graph.get_tensor_by_name(signature.inputs['image'].name)
        self.output_stegastamp = graph.get_tensor_by_name(signature.outputs['stegastamp'].name)
        self.output_residual = graph.get_tensor_by_name(signature.outputs['residual'].name)

        # Positional form works with bchlib < 1.0.0; bchlib 1.0.0 expects
        # bchlib.BCH(t=BCH_BITS, prim_poly=BCH_POLYNOMIAL) instead.
        bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
        ecc = bch.encode(data)
        packet = data + ecc

        # One list entry per bit, most significant bit of each byte first.
        packet_binary = ''.join(format(byte, '08b') for byte in packet)
        bits = [int(bit) for bit in packet_binary]
        # Four trailing zero bits; presumably pads the secret to the length
        # the model's 'secret' input expects -- TODO confirm against the model.
        bits.extend([0, 0, 0, 0])
        self.secret = bits

    def __call__(self, image):
        """Run the model on one image and return the encoded image.

        Args:
            image: Array-like with three axes, convertible to float32. Axes
                are reordered with ``(1, 2, 0)`` before being fed to the model
                (presumably channels-first in, channels-last for the model --
                confirm with the caller).

        Returns:
            torch.Tensor: The first element of the model's ``stegastamp``
            output with its axes reordered by ``(2, 0, 1)``.
        """
        image_arr = np.array(image, dtype=np.float32)
        image_arr = np.transpose(image_arr, (1, 2, 0))
        feed_dict = {
            self.input_secret: [self.secret],
            self.input_image: [image_arr],
        }
        # Only the stegastamp output is used, so the residual is not fetched.
        hidden_img = self.sess.run(self.output_stegastamp, feed_dict=feed_dict)
        return torch.tensor(np.transpose(hidden_img[0], (2, 0, 1)))

    def close(self):
        """Close the underlying TensorFlow session."""
        self.sess.close()
from flask import Flask, jsonify, request
import pickle
from queue import Queue
import argparse
# Flask application object; the route handler below is registered on it.
app = Flask(__name__)

# Default encoder settings, handed to issbaEncoder when run as a script.
model_path = './stegastamp_pretrained'
secret = 'Stega!!'
size = (224, 224)

# Pool of tokens used by the request handler: a handler takes one token
# before working and puts one back afterwards, so the queue starts out full.
queue = Queue(maxsize=8)
for i in range(queue.maxsize):
    queue.put(1)

# encoder = issbaEncoder(model_path=model_path,secret=secret,size=size)

@app.route('/encodeImage', methods=['POST'])
def encode_img():
    """Encode the image sent in the POST body and return the pickled result.

    The raw request body is unpickled, passed to the module-level ``encoder``,
    and the encoder's return value is pickled to form the response body.
    """
    # Take a token from the pool; blocks while none is available.
    queue.get()
    try:
        # SECURITY: pickle.loads can execute arbitrary code contained in the
        # payload. Only expose this endpoint to trusted clients, or switch to
        # a safe serialization format.
        images_np = pickle.loads(request.data)
        encoded_images = encoder(images_np)
        return pickle.dumps(encoded_images)
    finally:
        # Always hand the token back, even when decoding or encoding fails;
        # otherwise each failed request would permanently shrink the pool.
        queue.put(1)

 
 
if __name__ == '__main__':
    # Script entry point: build the encoder on the selected GPU and serve it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', default=9111, type=int)
    # GPU device index (or comma-separated list) made visible to TensorFlow,
    # e.g. "0" is the first GPU.
    parser.add_argument('--gpu', default='2', type=str)
    args = parser.parse_args()

    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

    # Use the tf.compat.v1 names, matching the API the encoder class uses.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.visible_device_list = args.gpu
    session = tf.compat.v1.Session(config=config)
    try:
        # Module-level name: the route handler looks up ``encoder`` globally.
        encoder = issbaEncoder(model_path=model_path, secret=secret, size=size, session=session)
        app.run(debug=False, host="127.0.0.1", port=args.port)
    finally:
        # Release the session even if model loading or the server fails.
        session.close()